import os
import numpy
import matplotlib
from jqc import jqc_plot
from scipy import constants
from matplotlib import pyplot
from diatom import Hamiltonian
from matplotlib import gridspec
from matplotlib.patches import ConnectionPatch,Rectangle
from matplotlib.collections import LineCollection
from sympy.physics.wigner import wigner_3j,wigner_9j
from matplotlib.colors import LogNorm,LinearSegmentedColormap

def make_segments(x, y):
    '''
    Join consecutive (x, y) points into straight segments for LineCollection.

    The result has shape (number_of_points - 1) x 2 x 2: for each segment,
    its start vertex followed by its end vertex, each stored as (x, y).
    '''
    vertices = numpy.column_stack((x, y))
    starts = vertices[:-1]
    ends = vertices[1:]
    return numpy.stack((starts, ends), axis=1)

def colorline(x, y, z=None, cmap=pyplot.get_cmap('copper'),
                norm=pyplot.Normalize(0.0, 1.0), linewidth=3, alpha=1.0,
                legend=False,ax=None):
    '''
    Plot a line through (x, y) whose segments are coloured by z.

    Parameters
    ----------
    x, y : array-like
        Coordinates of the line's vertices.
    z : array-like or scalar, optional
        Values passed through `norm` and `cmap` to colour the segments.
        Defaults to len(x) values evenly spaced on [0, 1]. A scalar is
        promoted to a one-element array.
    cmap : Colormap or registered colormap name
        Colour map used for z.
    norm : Normalize
        Maps z onto [0, 1] for the colour map. The default instance is
        created once, when the module is imported, and is shared between calls.
    linewidth : float
        Width of the drawn line.
    alpha, legend :
        Accepted for interface compatibility but currently unused.
        NOTE(review): forwarding a scalar alpha to LineCollection would likely
        override the per-value alpha of colour maps that carry an alpha
        channel, such as RbCs_map_tweak_blue. Confirm before wiring it up.
    ax : matplotlib Axes, optional
        Axes to draw on. The current axes (pyplot.gca()) are used when None.

    Returns
    -------
    LineCollection
        The collection added to `ax`. It can be used as the mappable for a
        colour bar.
    '''
    if ax is None:
        ax = pyplot.gca()

    # Default colours equally spaced on [0,1]:
    if z is None:
        z = numpy.linspace(0.0, 1.0, len(x))

    # Promote scalars and 0-d arrays to 1-D so LineCollection always receives
    # an array of values.
    z = numpy.atleast_1d(numpy.asarray(z))

    segments = make_segments(x, y)
    # Fixed zorder of 1.25 keeps the coloured line just above lines that are
    # drawn with zorder=1.0.
    lc = LineCollection(segments, array=z, cmap=cmap, norm=norm,
                        linewidth=linewidth,zorder=1.25)

    ax.add_collection(lc)

    return lc

def dipolez(Nmax,d):
    '''
    Build the z component of the dipole operator for a rigid rotor.

    The matrix is expressed in the uncoupled |N, M_N> basis. N runs from 0 to
    Nmax and, within each N, M_N runs from +N down to -N. Element (i, j) is

        d * sqrt((2*N1+1)*(2*N2+1)) * (-1)**M1
          * (N1 1 N2; -M1 0 M2) * (N1 1 N2; 0 0 0)

    where (...) are Wigner 3j symbols and (N1, M1), (N2, M2) are the
    quantum numbers of basis states i and j.

    Parameters
    ----------
    Nmax : int
        Highest rotational level included in the basis.
    d : number
        Overall scale of the matrix. This is presumably the molecule-frame
        dipole moment; the script passes 1.

    Returns
    -------
    numpy.ndarray
        Complex array of shape ((Nmax+1)**2, (Nmax+1)**2).
    '''
    # Rotational basis states in matrix order.
    states = [(N, M) for N in range(0, Nmax+1) for M in range(N, -(N+1), -1)]
    size = len(states)  # equals (Nmax+1)**2
    # numpy.complex (an alias of the builtin complex) was removed in
    # NumPy 1.24; complex128 is the same dtype.
    Dmat = numpy.zeros((size, size), dtype=numpy.complex128)
    for i, (N1, M1) in enumerate(states):
        for j, (N2, M2) in enumerate(states):
            # (N1 1 N2; -M1 0 M2) vanishes unless -M1 + 0 + M2 == 0, so every
            # element with M1 != M2 stays zero.
            if M1 != M2:
                continue
            # Convert the exact SymPy result to a float once, rather than
            # relying on NumPy to cast SymPy objects on assignment.
            angular = float(wigner_3j(N1, 1, N2, -M1, 0, M2)
                            * wigner_3j(N1, 1, N2, 0, 0, 0))
            Dmat[i, j] = d*numpy.sqrt((2*N1+1)*(2*N2+1))*(-1)**M1*angular
    return Dmat

# --- Set up the plotting environment ----------------------------------------
jqc_plot.plot_style("normal")
# Grid columns: Zeeman panels | spacer | Stark panels | spacer | colour bar.
grid = gridspec.GridSpec(2,5,width_ratios=[1,0.15,1,0.01,0.1])
cwd = os.path.dirname(os.path.abspath(__file__))
root = os.path.dirname(cwd)
# On Windows this gives the same "<cwd>\Data" string as plain concatenation,
# and it is also a valid path on other platforms.
fpath = os.path.join(cwd, "Data")

fig = pyplot.figure("EVERYTHING")

# --- Custom colour map for transition strengths ------------------------------
# LinearSegmentedColormap segment data: each channel lists
# (position, value_below, value_above) anchor points.
colour_dict_twk_blue = {
    "red" : [(0.0,244/255,244/255),
            (0.6,0,0),
            (1.0,0,0)] ,
    "green" : [(0.0,234/255,234/255),
            (0.6,70/255.0,70/255.0),
            (1.0,70/255,70/255)],
    "blue" : [(0.0,168/255,168/255),
            (0.6,127/255,127/255),
            (1.0,127/255,127/255)]
}

# Same colours plus an alpha channel that is transparent at 0 and becomes
# fully opaque from 0.5 upwards.
colour_dict_twk_blue_alpha = colour_dict_twk_blue.copy()
colour_dict_twk_blue_alpha['alpha'] = ((0.0, 0.0,0.0),
                   (0.25, .5, .5),
                   (0.5, 1., 1.),
                   (1.0, 1.0, 1.0))


RbCs_map_twk_blue = LinearSegmentedColormap("RbCs_map_tweak_blue",
                                                colour_dict_twk_blue_alpha)
# Register the map under its name so it can be requested later as
# cmap='RbCs_map_tweak_blue'. pyplot.register_cmap was removed in
# Matplotlib 3.9, so use the colormap registry when it exists and fall back
# to the old call on older releases.
try:
    matplotlib.colormaps.register(RbCs_map_twk_blue, force=True)
except AttributeError:
    pyplot.register_cmap(cmap=RbCs_map_twk_blue)

# --- Basis and constants -----------------------------------------------------
Nmax =5     # highest rotational level in the basis
I1 = 3/2    # nuclear spin quantum numbers (presumably Rb and Cs)
I2 = 7/2

h = constants.h     # Planck constant, J s

RbCs = Hamiltonian.RbCs

colours = jqc_plot.colours

# Quantum numbers [N, MN, MI1, MI2] of every uncoupled basis state. N is
# ascending; MN, MI1 and MI2 each run from their maximum down to their minimum.
# The plotting loops only read the N entry, indexed by eigenstate number.
# NOTE(review): that relies on each block of (2N+1)*32 sorted eigenstates
# belonging to rotational level N — confirm this holds for the stored data.
indices = [[N, MN, MI1, MI2]
           for N in range(0, Nmax+1)
           for MN in range(N, -(N+1), -1)
           for MI1 in numpy.arange(I1, -(I1+1), -1)
           for MI2 in numpy.arange(I2, -(I2+1), -1)]
# --- Panel (a): Zeeman structure, highlighting the mF = +5 states ----------

axZa = fig.add_subplot(grid[0,0])                  # N = 1, shifted by -980 MHz
axZb = fig.add_subplot(grid[1,0],sharex=axZa)      # N = 0

fname = "Fig1_Zeeman"

# Row 0 holds the magnetic field. It is multiplied by 1e4 to give gauss, so
# presumably it is stored in tesla. Every other row is one state's energy
# (divided by h and scaled by 1e-6 below to give MHz) across that field scan.
Hyperfine_energy = numpy.genfromtxt(
    os.path.join(fpath, "Energies", fname+".csv"), delimiter=',')
try:
    mF = numpy.genfromtxt(os.path.join(fpath, "MF", fname+".csv"),
                          delimiter=',')
    print("loaded mF")

except IOError:
    # No cached mF, so compute <psi_k|F_z|psi_k> for every eigenstate k from
    # the stored eigenvectors at field index 1.
    # NOTE(review): the result is written to Output/Sorted/Ncalc5_mF.csv, not
    # to the MF/ file read above, so this fallback will run again next time.
    Hyperfine_States = numpy.load(os.path.join(
        fpath, "Output", "Sorted", "multithread_N5_states.npy"))
    print("Calculating mF")
    Nvec,I1vec,I2vec = Hamiltonian.Generate_vecs(Nmax,I1,I2)
    F = Nvec+I1vec+I2vec
    Fz=F[2]
    psi = Hyperfine_States[:,:,1]
    # Conjugate the bra so this is a true expectation value even for complex
    # eigenvectors. The result is unchanged when the eigenvectors are real.
    mF = numpy.round(numpy.einsum('ik,ij,jk->k', psi.conj(), Fz, psi).real)
    mF_path = os.path.join(fpath, "Output", "Sorted", "Ncalc5_mF.csv")
    numpy.savetxt(mF_path,mF,delimiter=',')
    print("Saved mF to:"+mF_path)

B = Hyperfine_energy[0,:]*1e4
Hyperfine_energy = Hyperfine_energy[1:,:]

Nplotmax = 1
# 32 = (2*I1+1)*(2*I2+1) hyperfine states per rotational sub-level, so this is
# the number of states with N <= Nplotmax.
numberplot= numpy.sum([(2*x+1)*32 for x in range(Nplotmax+1)])

# Axes-fraction positions of the state labels, just right of the Zeeman axes.
# Entries are used in the order the highlighted states are drawn.
loc =[[1.1,0.3],[1.1,0.25],[1.1,0.5],[1.1,0.75]]
Label=["0","0","1","2"] #labels in ascending energy

colours_fixed = [colours['red'],colours['grayblue'],colours['green'],
                colours['purple']]
k = 0

bbox = dict(boxstyle="round,pad=0", ec="none", fc="w", alpha=1.0)

# N = 1 curves go on the upper axis, shifted down by 980 MHz; N = 0 curves go
# on the lower axis unshifted. Each mF = 5 state gets its own colour and a
# label; all other states are drawn faintly in sand.
zeeman_axes = {1: (axZa, 980), 0: (axZb, 0)}

for i in range(numberplot):
    index = indices[i]
    if index[0] not in zeeman_axes:
        continue
    ax, offset = zeeman_axes[index[0]]
    energy = 1e-6*Hyperfine_energy[i,:]/h - offset     # MHz

    if mF[i] == 5:
        ax.plot(B, energy, color=colours_fixed[k], zorder=2.0)
        ax.text(loc[k][0],loc[k][1],Label[k],fontsize=15,
                transform=ax.transAxes,clip_on=False,bbox=bbox,
                verticalalignment='center',
                horizontalalignment='center',zorder=3)
        k+=1
    else:
        ax.plot(B, energy, color=colours['sand'], zorder=1.0, alpha=0.5)

axZa.tick_params(labelbottom=False)
axZb.set_xlabel("Magnetic Field,\n $B_z$ (G)")
axZb.set_xlim(0,181.5)

axZa.text(0.01,1.05,"+980 MHz",fontsize=15,clip_on=False,transform=axZa.transAxes)

axZb.set_xticks([0,100,181.5])
axZb.set_xticks([20,40,60,80,120,140,160],minor=True)
axZb.set_xticklabels(["0","100","181.5"])

axZa.set_ylim(-1.05,0.55)
axZb.set_ylim(-1.3,0.3)

axZa.set_yticks([-1,-0.5,0,0.5])

axZb.set_yticks([-1,-0.5,0,0.5])

axZa.text(25,-0.5,"$N=1$",transform=axZa.transData,fontsize=15)
axZb.text(25,-0.75,"$N=0$",transform=axZb.transData,fontsize=15)


# Shared y-label, placed at the bottom edge of the upper axis so it sits
# between the two Zeeman panels.
axZa.text(-0.33,0,"Energy/$h$ (MHz)",transform=axZa.transAxes,rotation=90,
            horizontalalignment='center',verticalalignment='center')

axZb.text(0.01,0.05,"(a)",transform=axZb.transAxes,fontsize=18)

# --- Panel (b): DC Stark structure -----------------------------------------
# N = 1 lines are coloured by their transition strength from eigenstate 0.
axSa = fig.add_subplot(grid[0,2],sharey=axZa)
axSb = fig.add_subplot(grid[1,2],sharex=axSa,sharey=axZb)

fname = "Fig1_DC"

Hyperfine_energy = numpy.genfromtxt(
    os.path.join(fpath, "Energies", fname+".csv"), delimiter=',')

# Row 0 holds the electric field. It is scaled by 1e-2 for a V/cm axis, so
# presumably it is stored in V/m. Every other row is one state's energy.
E = 1e-2*Hyperfine_energy[0,:]
Hyperfine_energy = Hyperfine_energy[1:,:]

k=0
states_path = os.path.join(fpath, "Sorted", "multithread_N5_states.npy")

try:
    d = numpy.genfromtxt(os.path.join(fpath, "TDM", fname+".csv"),
                         delimiter=',', dtype=numpy.complex128)
    print("loaded TDM")

except IOError:
    Hyperfine_States = numpy.load(states_path)
    print("Calculating dipole moments")
    dz = dipolez(Nmax,1)
    # Extend the rotational operator to the full basis by taking the
    # identity on both nuclear spins.
    dz = numpy.kron(dz,numpy.kron(numpy.identity(int(2*I1+1)),
                    numpy.identity(int(2*I2+1))))
    # Row r of d is <state 0|d_z|state 32+r> at each field point, covering the
    # N = 1 states 32 .. numberplot-1. The bra is conjugated, and the slice
    # stops at numberplot; it previously included one extra, unused state.
    d = numpy.einsum('ix,ij,jkx->kx', Hyperfine_States[:,0,:].conj(), dz,
                     Hyperfine_States[:,32:numberplot,:])
    tdm_path = os.path.join(fpath, "Sorted", "Ncalc5_N1_TDMz.csv")
    numpy.savetxt(tdm_path,d,delimiter=',')
    print("Saved dipole moments to:"+tdm_path)

try:
    mF = numpy.genfromtxt(os.path.join(fpath, "MF", fname+".csv"),
                          delimiter=',')
    print("loaded mF")

except IOError:
    print("Calculating mF")
    # The eigenvectors are only defined here if an earlier fallback loaded
    # them. Otherwise this branch used to raise NameError, so load them now.
    # NOTE(review): an earlier fallback may have loaded the Output/Sorted
    # copy instead of this Sorted copy — confirm which file is intended.
    if "Hyperfine_States" not in globals():
        Hyperfine_States = numpy.load(states_path)
    Nvec,I1vec,I2vec = Hamiltonian.Generate_vecs(Nmax,I1,I2)
    F = Nvec+I1vec+I2vec
    Fz=F[2]
    psi = Hyperfine_States[:,:,1]
    # Conjugated bra: a true expectation value even for complex eigenvectors.
    mF = numpy.round(numpy.einsum('ik,ij,jk->k', psi.conj(), Fz, psi).real)
    mF_path = os.path.join(fpath, "Sorted", "Ncalc5_mF.csv")
    numpy.savetxt(mF_path,mF,delimiter=',')
    print("Saved mF to:"+mF_path)


def _link_zeeman_to_stark(ax_zeeman, ax_stark, y, label_xy, lw):
    '''
    Draw dashed guides joining a state's Zeeman curve to its Stark curve.

    The first guide runs from (181.5 G, y), the right edge of the Zeeman
    panel, to the state label at axes-fraction position label_xy. The second
    runs from that label to (0, y) on the Stark axes. Both patches are added
    to ax_zeeman.
    '''
    ax_zeeman.add_patch(ConnectionPatch(
        (181.5, y), label_xy, axesA=ax_zeeman, axesB=ax_zeeman, ls='--',
        lw=lw, zorder=2, coordsA=ax_zeeman.transData,
        coordsB=ax_zeeman.transAxes))
    ax_zeeman.add_patch(ConnectionPatch(
        label_xy, (0, y), axesA=ax_zeeman, axesB=ax_stark, ls='--',
        lw=lw, zorder=2, coordsA=ax_zeeman.transAxes,
        coordsB=ax_stark.transData))


# The guides use each state's energy at the first electric-field point.
# NOTE(review): this presumably matches the end of the Zeeman curve only if
# the DC scan was computed at B = 181.5 G — confirm.
for i in range(numberplot):
    index = indices[i]

    if index[0] ==1:
        energy = 1e-6*Hyperfine_energy[i,:]/h-980      # MHz, shifted
        # |<0|d_z|i>|; rows of d start at the first N = 1 state (i = 32).
        TDM = numpy.abs(d[i-32,:])
        axSa.plot(E, energy, color=colours['sand'], alpha=0.5, zorder=1.0)

        # 3*|d|^2 is the colour value on a log scale spanning 1e-2 .. 1.
        # The factor 3 presumably normalises the strongest N=0 -> N=1
        # line (|d|^2 = 1/3 for d = 1) to 1.
        cl = colorline(E, energy, 3*TDM**2,
                        cmap='RbCs_map_tweak_blue',norm=LogNorm(1e-2,1.0),
                        linewidth=2.0,ax=axSa)

        if mF[i]==5:
            _link_zeeman_to_stark(axZa, axSa, energy[0], tuple(loc[k]), 1)
            k+=1
    elif index[0]==0:
        energy = 1e-6*Hyperfine_energy[i,:]/h           # MHz
        if i==0:
            axSb.plot(E, energy, color=colours['red'], zorder=2.0)
            _link_zeeman_to_stark(axZb, axSb, energy[0], tuple(loc[k]), 1.5)
            k+=1
        else:
            axSb.plot(E, energy, color=colours['sand'], alpha=0.5, zorder=1.0)
axSa.set_xlim(0,125)

axSb.set_xlabel("Electric Field,\n $E_z$ (V$\\,$cm$^{-1}$)")

axSa.tick_params(labelbottom=False,labelleft=False)
axSb.tick_params(labelleft=False)

axSb.text(0.01,0.05,"(b)",transform=axSb.transAxes,fontsize=18)

# --- Finalise the figure ----------------------------------------------------
# Shared colour bar for the N = 1 transition-strength colouring. cl is the
# last LineCollection drawn on the Stark panel; every such line uses the same
# cmap and LogNorm(1e-2, 1.0).
colax=fig.add_subplot(grid[:,-1])
colax.set_title("$z$",color=colours["blue"],fontsize=15)
fig.colorbar(cl,cax = colax)
colax.set_ylabel("Relative Transition Strength")

fig.subplots_adjust(bottom =0.19,hspace=0.15,wspace=0.07,left=0.15,top=0.93,
                    right=0.86)

# os.path.join gives the same path as before on Windows and also works on
# other platforms.
pyplot.savefig(os.path.join(cwd, "Stark_Zeeman_low.pdf"))

pyplot.show()
